import time
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from diffusers import DDIMInverseScheduler, DPMSolverMultistepInverseScheduler, DDIMScheduler, \
    DPMSolverMultistepScheduler
from torch import nn

from utils_loc import load_depth, load_image, load_512
from utils_loc.inversion_utils import NullInversion, DirectInversion
from utils_loc.layer_utils import set_custom_cur_t, set_custom_invert_mode, get_custom_modules, set_custom_load_mode, \
    reset_custom_cache


class CIE(nn.Module):
    def __init__(self, pipe, config):
        """
        Wrap an already-loaded diffusion pipeline and prepare schedulers for inversion and editing.
        :param pipe: Pipeline whose ``vae``, ``tokenizer``, ``unet`` and ``text_encoder`` are reused here.
        :param config: Run configuration (device, model key, scheduler name, directories, float
            precision, and the ``inversion`` / ``generation`` sub-sections read below).
        :raises ValueError: If ``config.scheduler_name`` is neither "DDIM" nor "DPM-Solver".
        """
        super().__init__()

        self.device = config.device
        self.use_depth = config.sd_version == "depth"
        self.model_key = config.model_key
        self.scheduler_name = config.scheduler_name
        self.config = config
        # Results are written to <work_dir>/<inversion method>-<editing method>.
        self.work_dir = Path(config.work_dir) / f"{config.inversion.method}-{config.method}"
        self.data_dir = Path(config.dataset_path)

        # Any precision string other than "fp16" selects full precision.
        self.dtype = torch.float16 if config.float_precision == "fp16" else torch.float32

        # Keep direct handles to the pipeline and its sub-models.
        self.pipe = pipe
        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.unet = pipe.unet
        self.text_encoder = pipe.text_encoder
        self.invert_steps = config.inversion.steps
        self.generate_steps = self.config.generation.steps
        self.latent_repeat_itr = self.config.generation.cache.latent_repeat_itr

        self.inverse_method = config.inversion.method

        # scheduler name -> (inverse scheduler class, forward scheduler class, extra from_pretrained kwargs)
        scheduler_table = {
            "DDIM": (DDIMInverseScheduler, DDIMScheduler, {}),
            "DPM-Solver": (DPMSolverMultistepInverseScheduler, DPMSolverMultistepScheduler,
                           {"solver_order": 2}),
        }
        if self.scheduler_name not in scheduler_table:
            raise ValueError(f"Unsupported scheduler: {self.scheduler_name}")
        inverse_cls, forward_cls, extra_kwargs = scheduler_table[self.scheduler_name]
        self.invert_scheduler = inverse_cls.from_pretrained(self.model_key, subfolder="scheduler",
                                                            **extra_kwargs)
        self.generate_scheduler = forward_cls.from_pretrained(self.model_key, subfolder="scheduler",
                                                              **extra_kwargs)

        # Both schedulers are configured with the generation step count.
        for scheduler in (self.generate_scheduler, self.invert_scheduler):
            scheduler.set_timesteps(self.generate_steps)

        # Filled in later through set_prompts_and_paths().
        self.source_prompts = None
        self.target_prompts = None
        self.img_paths = None

        self.negative_prompt = config.generation.negative_prompt
        self.guidance_scale = config.generation.guidance_scale
        self.custom_models = get_custom_modules(pipe.unet, config)

    def set_prompts_and_paths(self, source_prompts, target_prompts, img_paths):
        """
        Register the prompts and image paths that the next run will use.
        :param source_prompts: List of source prompts (used as conditioning during inversion).
        :param target_prompts: List of target prompts (used as conditioning during generation).
        :param img_paths: List of image paths; ``__call__`` joins them onto the dataset directory
            to load the inputs and onto the work directory to save the results.
        """
        self.source_prompts, self.target_prompts, self.img_paths = (
            source_prompts,
            target_prompts,
            img_paths,
        )

    @torch.no_grad()
    def encode_image_to_latents(self, image):
        """
        Encode an image batch with the VAE and return the scaled latent mean.
        :param image: Image tensor; it is remapped with ``2 * image - 1`` before encoding, so values
            are presumably expected in [0, 1].
        :return: Mean of the VAE posterior multiplied by 0.18215 (no sampling is performed).
        """
        autocast_ctx = torch.autocast(device_type=self.device, dtype=self.dtype)
        with autocast_ctx:
            # Shift the image to [-1, 1] before handing it to the encoder.
            posterior = self.vae.encode(image * 2.0 - 1.0).latent_dist
            return posterior.mean * 0.18215

    @torch.no_grad()
    def decode_latents(self, latents):
        """
        Decode latents with the VAE into images clamped to [0, 1].
        :param latents: Latent tensor scaled by 0.18215 (the scaling applied by ``encode_image_to_latents``).
        :return: Decoded images, mapped with ``x / 2 + 0.5`` and clamped to [0, 1].
        """
        with torch.autocast(device_type=self.device, dtype=self.dtype):
            # Undo the latent scaling before decoding.
            decoded = self.vae.decode(latents * (1 / 0.18215)).sample
            return torch.clamp(decoded / 2 + 0.5, 0, 1)

    @torch.no_grad()
    def decode_latents_batch(self, latents):
        """
        Decode a batch of latents one sample at a time and re-assemble the batch.
        :param latents: Latent tensor with the batch on dim 0.
        :return: The per-sample outputs of ``decode_latents`` concatenated along dim 0.
        """
        # Each chunk keeps a batch dimension of size 1, so the VAE only ever sees one sample.
        decoded = [self.decode_latents(single) for single in torch.split(latents, 1, dim=0)]
        return torch.cat(decoded)

    def prepare_depth(self, pipe, imgs):
        """
        Obtain one depth map per image via ``load_depth`` and stack them into a batch.
        :param pipe: Pipeline handed through to ``load_depth``.
        :param imgs: Image batch; entry ``i`` must correspond to ``self.img_paths[i]``.
        :return: The per-image ``load_depth`` results concatenated along dim 0.
        """
        depth_root = self.work_dir.joinpath("depth")
        depth_root.mkdir(parents=True, exist_ok=True)

        def _depth_for(index):
            # The per-image folder is the image path with every ".jpg" / ".png" occurrence removed.
            folder = self.img_paths[index].replace(".jpg", "").replace(".png", "")
            target = depth_root / folder / "depth.pt"
            single_img = imgs[index].unsqueeze(0)  # Add batch dimension
            return load_depth(pipe, target, single_img)

        return torch.cat([_depth_for(index) for index in range(imgs.shape[0])], dim=0)

    @torch.no_grad()
    def get_text_embeds(self, prompts, negative_prompt, device=None):
        """
        Encode prompts and the negative prompt into text-encoder embeddings.
        :param prompts: List of prompts; each one is tokenized and encoded separately.
        :param negative_prompt: Negative prompt; encoded once and repeated once per entry of ``prompts``.
        :param device: Device the token ids are moved to before encoding. Defaults to ``self.device``.
        :return: Embeddings concatenated along dim 0 as ``[negative] * len(prompts) + [one per prompt]``,
            i.e. the negative half comes first.
        """
        # Follow the model's configured device unless the caller explicitly overrides it.
        if device is None:
            device = self.device
        max_length = self.tokenizer.model_max_length

        def _encode(text):
            # Pad and truncate to the tokenizer's max length so every embedding has the same
            # sequence length and over-long text cannot exceed what the encoder accepts.
            token_ids = self.tokenizer(text, padding='max_length', max_length=max_length,
                                       truncation=True, return_tensors='pt').input_ids
            return self.text_encoder(token_ids.to(device))[0]

        # The negative prompt does not depend on the loop variable, so encode it only once.
        negative_embedding = _encode(negative_prompt)
        prompt_embeddings = [_encode(prompt) for prompt in prompts]

        return torch.cat([negative_embedding] * len(prompt_embeddings) + prompt_embeddings)

    @torch.no_grad()
    def invert(self, img):
        """
        Invert an image batch into a trajectory of latents using the inverse scheduler.

        The UNet is conditioned on the source prompts only (no classifier-free guidance). The custom
        modules are switched to "invert" mode and their cache is reset before the loop; they are
        switched to "generate" mode afterwards.
        :param img: Image tensor forwarded to ``encode_image_to_latents``.
        :return: List of latents; index 0 is the VAE encoding of ``img`` and every following entry is
            the latent after one more inverse-scheduler step.
        """
        text_embeddings = self.get_text_embeds(self.source_prompts, self.negative_prompt)
        # get_text_embeds returns [negative..., prompt...]; keep only the prompt half for inversion.
        text_embeddings = text_embeddings[len(self.source_prompts):, ]
        latent = self.encode_image_to_latents(img)
        latent_list = [latent]

        # Re-arm the scheduler on every call: multistep solvers keep per-run state (step counter,
        # previous model outputs) that would otherwise leak in from an earlier inversion when this
        # instance is reused. The step count matches the one configured in __init__.
        self.invert_scheduler.set_timesteps(self.generate_steps)
        timesteps = self.invert_scheduler.timesteps
        set_custom_invert_mode(self.custom_models, mode="invert")
        reset_custom_cache(self.custom_models)
        for i, t in enumerate(timesteps):
            # NOTE(review): the index counts down from invert_steps while the loop length comes from
            # generate_steps; presumably both are expected to be equal — confirm.
            set_custom_cur_t(self.custom_models, self.invert_steps - i - 1)
            if self.use_depth:
                # Depth is appended along the channel dimension.
                latent_input = torch.cat([latent, self.depths.to(latent)], dim=1)
            else:
                latent_input = latent
            with torch.autocast(device_type=self.device, dtype=self.dtype):
                model_output = self.unet(latent_input, t, encoder_hidden_states=text_embeddings).sample
            latent = self.invert_scheduler.step(model_output, timestep=t, sample=latent).prev_sample
            latent_list.append(latent)
        set_custom_invert_mode(self.custom_models, mode="generate")

        return latent_list

    @torch.no_grad()
    def ddim_sample(self, latent, text_embeddings):
        """
        Run the forward scheduler from ``latent`` and record every intermediate latent.

        A fresh scheduler is created on each call. The custom modules are put in "invert" mode for
        the duration of the loop and switched to "generate" mode afterwards.
        :param latent: Starting latent.
        :param text_embeddings: Text embeddings passed to the UNet as ``encoder_hidden_states``.
        :return: List of latents in reverse order of production: the final latent first and the
            input ``latent`` last.
        :raises ValueError: If ``self.scheduler_name`` is neither "DDIM" nor "DPM-Solver".
        """
        # scheduler name -> (scheduler class, extra from_pretrained kwargs)
        scheduler_table = {
            "DDIM": (DDIMScheduler, {}),
            "DPM-Solver": (DPMSolverMultistepScheduler, {"solver_order": 2}),
        }
        if self.scheduler_name not in scheduler_table:
            raise ValueError(f"Unsupported scheduler: {self.scheduler_name}")
        scheduler_cls, extra_kwargs = scheduler_table[self.scheduler_name]
        sampler = scheduler_cls.from_pretrained(self.model_key, subfolder="scheduler", **extra_kwargs)
        sampler.set_timesteps(self.generate_steps)

        trajectory = [latent]
        set_custom_invert_mode(self.custom_models, mode="invert")

        for step_t in sampler.timesteps:
            # NOTE(review): the timestep value is passed here, whereas invert()/generate() pass a
            # step index — confirm this is intended.
            set_custom_cur_t(self.custom_models, step_t)
            with torch.autocast(device_type=self.device, dtype=self.dtype):
                # Depth, when used, is appended along the channel dimension.
                unet_input = (torch.cat([latent, self.depths.to(latent)], dim=1)
                              if self.use_depth else latent)
                noise_pred = self.unet(unet_input, step_t, encoder_hidden_states=text_embeddings).sample

            latent = sampler.step(noise_pred, step_t, latent).prev_sample
            trajectory.append(latent)

        set_custom_invert_mode(self.custom_models, mode="generate")
        return trajectory[::-1]

    @torch.no_grad()
    def generate(self, latent_list, uncond_embeddings=None, noise_loss_list=None):
        """
        Denoise from the last latent of ``latent_list`` under the target prompts with classifier-free guidance.

        The model input is a two-part batch: the first part is paired with the negative-prompt
        embeddings and the second with the target-prompt embeddings. For the first
        ``latent_repeat_itr + 1`` steps the first part is taken from the supplied latent trajectory
        instead of the current latent.
        :param latent_list: List of latents; it is read in reverse order, so its last entry is the
            starting latent. The caller's list is not modified.
        :param uncond_embeddings: Optional per-step unconditional embeddings, indexed with ``[-1 - i]``
            at step ``i``; when given they replace the first row of the text embeddings.
        :param noise_loss_list: Optional per-step offsets, indexed with ``[-1 - i]`` at step ``i``;
            after the repeat phase the first batch entry is added to the latent of the first part.
        :return: The latent after the final scheduler step.
        """
        text_embeddings = self.get_text_embeds(self.target_prompts, self.negative_prompt)

        set_custom_invert_mode(self.custom_models, mode="generate")

        # Re-arm the scheduler on every call: multistep solvers keep per-run state, and this
        # scheduler object is also handed to the external inversion helpers, so it may have been
        # stepped already. The step count matches the one configured in __init__.
        self.generate_scheduler.set_timesteps(self.generate_steps)
        timesteps = self.generate_scheduler.timesteps
        # Work on a reversed copy so the caller's list keeps its order (an in-place reverse would
        # flip it back on a second call with the same list).
        latent_list = list(reversed(latent_list))
        latent = latent_list[0]

        for i, t in enumerate(timesteps):
            # Cached features / attention are loaded only for the leading fraction of steps
            # given by cache_f_t / cache_attn_t.
            set_custom_load_mode(self.custom_models,
                                 load_feature=i <= int(self.generate_steps * self.config.generation.cache.cache_f_t),
                                 load_attn=i <= int(self.generate_steps * self.config.generation.cache.cache_attn_t))
            set_custom_cur_t(self.custom_models, i)
            if i <= self.latent_repeat_itr:
                latent_model_input = torch.cat([latent_list[i], latent])  # for classifier-free guidance
            elif noise_loss_list is not None:
                latent_model_input = torch.cat([latent + noise_loss_list[-1 - i][:1], latent])
            else:
                latent_model_input = torch.cat([latent, latent])
            if self.use_depth:
                # Depth is duplicated for both parts and appended along the channel dimension.
                latent_model_input = torch.cat(
                    [latent_model_input, self.depths.repeat(2, 1, 1, 1).to(latent_model_input)], dim=1)
            if uncond_embeddings is not None:
                # Swap in this step's unconditional embedding; the conditional rows are kept.
                text_embeddings = torch.cat([uncond_embeddings[-1 - i], text_embeddings[1:, ]], dim=0)
            with torch.autocast(device_type=self.device, dtype=self.dtype):
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample

            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)

            latent = self.generate_scheduler.step(noise_pred, t, latent).prev_sample
            set_custom_load_mode(self.custom_models, load_feature=False, load_attn=False)
        set_custom_invert_mode(self.custom_models, mode="invert")
        return latent

    def __call__(self):
        """
        Run one full edit for the prompts and images registered via ``set_prompts_and_paths``.

        Steps: load the images from the dataset directory, optionally prepare depth maps, invert the
        images, obtain the latent trajectory for the configured inversion method, regenerate under
        the target prompts, and save every edited image under the work directory using its original
        relative path.
        :raises ValueError: If the inversion method is unknown, or if "directinversion" /
            "null-text-inversion" is used with a batch size other than one.
        """
        # Resolve the image paths against the dataset directory.
        data_paths = [self.data_dir.joinpath(img_path) for img_path in self.img_paths]
        imgs = load_image(data_paths).to(self.device)

        if self.use_depth:
            self.depths = self.prepare_depth(self.pipe, imgs)

        # invert the image to noise
        # NOTE(review): this runs for every inversion method, even those that take their latents
        # from an external helper below; presumably it also fills the custom-module cache — confirm.
        inverted_x = self.invert(imgs)

        uncond_embeddings = None
        noise_loss_list = None
        batch_size = len(self.source_prompts)
        if self.inverse_method == "ddim":
            latent_list = inverted_x
        elif self.inverse_method == "directinversion":
            if batch_size != 1:
                raise ValueError("Direct inversion only supports single image inversion for Direct Inversion.")
            direct_inversion = DirectInversion(model=self.pipe,
                                               num_ddim_steps=len(self.generate_scheduler.timesteps),
                                               scheduler=self.generate_scheduler,
                                               depth=self.depths if self.use_depth else None, )

            # Perform direct inversion
            # NOTE(review): guidance_scale is fixed at 7.5 here while the null-text branch uses
            # self.guidance_scale — confirm this is intended.
            _, _, latent_list, noise_loss_list = direct_inversion.invert(
                image_gt=load_512(str(data_paths[0])),
                prompt=[self.source_prompts[0], self.target_prompts[0]],
                guidance_scale=7.5
            )
        elif self.inverse_method == "null-text-inversion":
            # Validate before building the inverter so a bad batch fails fast.
            if batch_size != 1:
                raise ValueError("Null-text inversion only supports single image inversion.")
            null_inversion = NullInversion(model=self.pipe,
                                           num_ddim_steps=len(self.generate_scheduler.timesteps),
                                           scheduler=self.generate_scheduler,
                                           depth=self.depths if self.use_depth else None, )
            _, _, latent_list, uncond_embeddings = null_inversion.invert(
                image_gt=load_512(str(data_paths[0])),
                prompt=self.source_prompts[0],
                guidance_scale=self.guidance_scale,
                num_inner_steps=5
            )
        else:
            raise ValueError(f"Unsupported inverse method: {self.inverse_method}")

        edited_latents = self.generate(latent_list, uncond_embeddings, noise_loss_list)
        edited_imgs = self.decode_latents_batch(edited_latents)

        for _b in range(edited_imgs.shape[0]):
            save_path = self.work_dir.joinpath(self.img_paths[_b])
            # make sure the root directory exists
            save_path.parent.mkdir(parents=True, exist_ok=True)
            # Indexing one image and transposing (1, 2, 0) moves its first axis last, then scale to 0-255.
            save_img = Image.fromarray(np.uint8(edited_imgs[_b].cpu().numpy().transpose(1, 2, 0) * 255))
            save_img.save(save_path)
